{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5a721e6d-c72c-42a6-8f10-f7b13e17824f",
   "metadata": {},
   "source": [
    "This notebook demonstrates the procedure used to train the multi-label models developed during the 2023 TRAM effort.\n",
    "\n",
    "Only the annotations produced during this effort can be adapted for multi-label modeling. First, we will load the annotations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f37d6d05-df9f-449b-b813-a1d03cb031a7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sentence</th>\n",
       "      <th>labels</th>\n",
       "      <th>doc_title</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>title: NotPetya Technical Analysis – A Triple ...</td>\n",
       "      <td>[]</td>\n",
       "      <td>NotPetya Technical Analysis  A Triple Threat F...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Executive Summary This technical analysis prov...</td>\n",
       "      <td>[]</td>\n",
       "      <td>NotPetya Technical Analysis  A Triple Threat F...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>For more information on CrowdStrike’s proactiv...</td>\n",
       "      <td>[]</td>\n",
       "      <td>NotPetya Technical Analysis  A Triple Threat F...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>NotPetya combines ransomware with the ability ...</td>\n",
       "      <td>[]</td>\n",
       "      <td>NotPetya Technical Analysis  A Triple Threat F...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>It spreads to Microsoft Windows machines using...</td>\n",
       "      <td>[T1210]</td>\n",
       "      <td>NotPetya Technical Analysis  A Triple Threat F...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19173</th>\n",
       "      <td>[2] Eclypsium Blog - TrickBot Now Offers 'Tric...</td>\n",
       "      <td>[]</td>\n",
       "      <td>AA21076A TrickBot Malware</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19174</th>\n",
       "      <td>Initial Version March 24, 2021:</td>\n",
       "      <td>[]</td>\n",
       "      <td>AA21076A TrickBot Malware</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19175</th>\n",
       "      <td>Added MITRE ATT&amp;CK Technique T1592.003 used fo...</td>\n",
       "      <td>[]</td>\n",
       "      <td>AA21076A TrickBot Malware</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19176</th>\n",
       "      <td>Added new MITRE ATT&amp;CKs and updated Table 1</td>\n",
       "      <td>[]</td>\n",
       "      <td>AA21076A TrickBot Malware</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19177</th>\n",
       "      <td>This product is provided subject to this Notif...</td>\n",
       "      <td>[]</td>\n",
       "      <td>AA21076A TrickBot Malware</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>19178 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                sentence   labels   \n",
       "0      title: NotPetya Technical Analysis – A Triple ...       []  \\\n",
       "1      Executive Summary This technical analysis prov...       []   \n",
       "2      For more information on CrowdStrike’s proactiv...       []   \n",
       "3      NotPetya combines ransomware with the ability ...       []   \n",
       "4      It spreads to Microsoft Windows machines using...  [T1210]   \n",
       "...                                                  ...      ...   \n",
       "19173  [2] Eclypsium Blog - TrickBot Now Offers 'Tric...       []   \n",
       "19174                    Initial Version March 24, 2021:       []   \n",
       "19175  Added MITRE ATT&CK Technique T1592.003 used fo...       []   \n",
       "19176       Added new MITRE ATT&CKs and updated Table 1        []   \n",
       "19177  This product is provided subject to this Notif...       []   \n",
       "\n",
       "                                               doc_title  \n",
       "0      NotPetya Technical Analysis  A Triple Threat F...  \n",
       "1      NotPetya Technical Analysis  A Triple Threat F...  \n",
       "2      NotPetya Technical Analysis  A Triple Threat F...  \n",
       "3      NotPetya Technical Analysis  A Triple Threat F...  \n",
       "4      NotPetya Technical Analysis  A Triple Threat F...  \n",
       "...                                                  ...  \n",
       "19173                          AA21076A TrickBot Malware  \n",
       "19174                          AA21076A TrickBot Malware  \n",
       "19175                          AA21076A TrickBot Malware  \n",
       "19176                          AA21076A TrickBot Malware  \n",
       "19177                          AA21076A TrickBot Malware  \n",
       "\n",
       "[19178 rows x 3 columns]"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Each record holds a sentence, its list of labels, and the title of its source document.\n",
    "data = pd.read_json('../data/tram2-data/multi_label.json')\n",
    "\n",
    "# Flatten the per-sentence label lists, discard the NaN placeholders left behind by\n",
    "# empty lists, and keep one entry per distinct label.\n",
    "all_labels = (\n",
    "    data['labels']\n",
    "    .explode()\n",
    "    .dropna()\n",
    "    .unique()\n",
    ")\n",
    "\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1988180-c3fe-446f-9332-1eacd7d6d378",
   "metadata": {},
   "outputs": [],
   "source": [
    "import transformers\n",
    "import torch\n",
    "\n",
    "# Selects which pretrained architecture is fine-tuned: 'bert' (SciBERT) or 'gpt' (GPT-2).\n",
    "mode: 'bert or gpt' = 'bert'\n",
    "\n",
    "# Use the GPU when one is available; otherwise fall back to the CPU so the notebook\n",
    "# still runs. The variable keeps the name `cuda` because later cells refer to it.\n",
    "cuda = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# `problem_type` is passed explicitly so that the classification head is trained as a\n",
    "# multi-label classifier instead of relying on the library inferring this from the\n",
    "# dtype of the label tensor supplied at training time.\n",
    "if mode == 'bert':\n",
    "    model = transformers.BertForSequenceClassification.from_pretrained(\n",
    "        \"allenai/scibert_scivocab_uncased\",\n",
    "        num_labels=len(all_labels),\n",
    "        problem_type=\"multi_label_classification\",\n",
    "        output_attentions=False,\n",
    "        output_hidden_states=False,\n",
    "    )\n",
    "    tokenizer = transformers.BertTokenizer.from_pretrained(\"allenai/scibert_scivocab_uncased\", max_length=512)\n",
    "elif mode == 'gpt':\n",
    "    model = transformers.GPT2ForSequenceClassification.from_pretrained(\n",
    "        \"gpt2\",\n",
    "        num_labels=len(all_labels),\n",
    "        problem_type=\"multi_label_classification\",\n",
    "        output_attentions=False,\n",
    "        output_hidden_states=False,\n",
    "    )\n",
    "    tokenizer = transformers.GPT2Tokenizer.from_pretrained(\"gpt2\", max_length=512)\n",
    "    # GPT-2 ships without a padding token, so the end-of-sequence token is reused for padding.\n",
    "    tokenizer.pad_token = tokenizer.eos_token\n",
    "    model.config.pad_token_id = tokenizer.pad_token_id\n",
    "else:\n",
    "    raise ValueError(f\"mode must be one of bert or gpt, but is {mode = !r}\")\n",
    "\n",
    "model.train().to(cuda)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96dba89b-cc19-4e07-be02-e8b1627a2664",
   "metadata": {},
   "source": [
    "For single-label modeling, we represented the labels using one hot encoding. The representation here is similar, except each row can have more than one `1` if it represents a multi-label instance. Some rows will not have any `1`s if they represent a negative sample."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bfe68106-7b2d-445f-a0cf-e684b7421a6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import MultiLabelBinarizer as MLB\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Fit the binarizer on singleton label lists so that every known label gets its own column.\n",
    "mlb = MLB()\n",
    "mlb.fit([[c] for c in all_labels])\n",
    "\n",
    "# A fixed `random_state` makes the shuffled 80/20 split identical on every run, so that\n",
    "# the scores reported at the end of the notebook can be reproduced.\n",
    "train, test = train_test_split(data, test_size=0.2, shuffle=True, random_state=42)\n",
    "\n",
    "def load_data(x, y, batch_size=10):\n",
    "    \"\"\"Yield aligned `(x, y)` mini-batches, each moved to the `cuda` device.\n",
    "\n",
    "    :x: inputs, sliced along the first axis\n",
    "    :y: targets; must have as many rows as `x`\n",
    "    :batch_size: maximum number of rows per batch (the last batch may be shorter)\n",
    "    \"\"\"\n",
    "    n_inputs = x.shape[0]\n",
    "    n_targets = y.shape[0]\n",
    "    assert n_inputs == n_targets\n",
    "    for start in range(0, n_inputs, batch_size):\n",
    "        stop = start + batch_size\n",
    "        yield x[start:stop].to(cuda), y[start:stop].to(cuda)\n",
    "\n",
    "def tokenize(instances: 'list[str]'):\n",
    "    \"\"\"Return the token ids of `instances` as a tensor, each row padded or truncated to 512 tokens.\"\"\"\n",
    "    encoded = tokenizer(\n",
    "        instances,\n",
    "        max_length=512,\n",
    "        truncation=True,\n",
    "        padding='max_length',\n",
    "        return_tensors='pt',\n",
    "    )\n",
    "    return encoded.input_ids\n",
    "\n",
    "def encode_labels(labels):\n",
    "    \"\"\":labels: should be the `labels` column (a Series) of the DataFrame\n",
    "\n",
    "    Returns a float tensor holding the indicator matrix that `mlb` produces for `labels`.\n",
    "    \"\"\"\n",
    "    indicator_rows = mlb.transform(labels.to_numpy())\n",
    "    return torch.Tensor(indicator_rows)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d096515b-da21-4761-8d74-88b9085dbb4c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[  102,  7208,  4531,  ...,     0,     0,     0],\n",
       "        [  102,   260, 24391,  ...,     0,     0,     0],\n",
       "        [  102,  4975, 11554,  ...,     0,     0,     0],\n",
       "        ...,\n",
       "        [  102,   407, 14382,  ...,     0,     0,     0],\n",
       "        [  102,   256,   165,  ...,     0,     0,     0],\n",
       "        [  102,   111,  2057,  ...,     0,     0,     0]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Token ids for the training sentences; every row is padded out to the same width.\n",
    "x_train = tokenize(\n",
    "    train['sentence'].tolist()\n",
    ")\n",
    "x_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "81ea37c7-9abc-4866-ab1e-af5b37df5986",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 0., 0.,  ..., 1., 0., 0.],\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        ...,\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Indicator matrix for the training labels: one row per sentence, one column per label.\n",
    "y_train = encode_labels(\n",
    "    train['labels']\n",
    ")\n",
    "y_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3af62a5a-0f60-405c-9b60-c5f6784d2881",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from statistics import mean\n",
    "\n",
    "from tqdm import tqdm\n",
    "from torch.optim import AdamW\n",
    "\n",
    "optim = AdamW(model.parameters(), lr=2e-5, eps=1e-8)\n",
    "\n",
    "train_batch_size = 10\n",
    "# Ceiling division: the final batch may be shorter than `train_batch_size`.\n",
    "n_batches = -(-x_train.shape[0] // train_batch_size)\n",
    "\n",
    "for epoch in range(6):\n",
    "    # Re-enter training mode every epoch so that this cell still trains correctly when it\n",
    "    # is re-run after the evaluation cell has called `model.eval()`.\n",
    "    model.train()\n",
    "    epoch_losses = []\n",
    "    # `load_data` is a generator and has no length, so tqdm needs `total` to draw a progress bar.\n",
    "    for x, y in tqdm(load_data(x_train, y_train, batch_size=train_batch_size), total=n_batches):\n",
    "        model.zero_grad()\n",
    "        # The attention mask is 1 wherever the token is not the padding token.\n",
    "        out = model(x, attention_mask=x.ne(tokenizer.pad_token_id).to(int), labels=y)\n",
    "        epoch_losses.append(out.loss.item())\n",
    "        out.loss.backward()\n",
    "        optim.step()\n",
    "    print(f\"epoch {epoch + 1} loss: {mean(epoch_losses)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e5dcd4c-5d20-4392-bbef-9a2938fbbf58",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics import precision_recall_fscore_support as calculate_score\n",
    "\n",
    "# Put the model into evaluation mode before predicting.\n",
    "model.eval()\n",
    "\n",
    "x_test = tokenize(test['sentence'].tolist())\n",
    "\n",
    "batch_size = 20\n",
    "preds = []\n",
    "\n",
    "# No gradients are tracked while predicting.\n",
    "with torch.no_grad():\n",
    "    for start in range(0, x_test.shape[0], batch_size):\n",
    "        stop = start + batch_size\n",
    "        x = x_test[start:stop].to(cuda)\n",
    "        attention_mask = x.ne(tokenizer.pad_token_id).to(int)\n",
    "        out = model(x, attention_mask=attention_mask)\n",
    "        # Extending a list with a tensor appends its rows one at a time.\n",
    "        preds.extend(out.logits.cpu())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ede06fc4-bff4-4f1c-a384-6eff4ad8be6c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Stack the per-sentence logits, squash them through a sigmoid, and mark every label\n",
    "# whose value exceeds 0.5 as predicted.\n",
    "binary_preds = (\n",
    "    torch.vstack(preds)\n",
    "    .sigmoid()\n",
    "    .gt(.5)\n",
    "    .to(int)\n",
    ")\n",
    "\n",
    "# Both columns hold frozensets of label names and share a fresh 0..n-1 index so that\n",
    "# predictions and ground truth line up row by row.\n",
    "predicted = (\n",
    "    pd.Series(mlb.inverse_transform(binary_preds))\n",
    "    .apply(frozenset)\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "actual = (\n",
    "    test['labels']\n",
    "    .apply(frozenset)\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "results = pd.concat({'preds': predicted, 'actual': actual}, axis=1)\n",
    "\n",
    "results"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5748c29e-109c-4757-916c-e9dc3c23a803",
   "metadata": {},
   "source": [
    "While the formulae for precision, recall, and F1 are the same for multi-label evaluation, the procedure for counting true positives, false positives, and false negatives is not.\n",
    "\n",
    "Where $P$ is the set of predicted labels for a given instance, and $A$ is the set of actual labels for the same instance, the true positives are the labels in $P \\cap A$, the false positives are those in $P - A$, and the false negatives are those in $A - P$.\n",
    "\n",
    "To give an example, if the actual labels for a sample are $\\{a, c\\}$, and the model predicts $\\{c, d\\}$, that is a true positive for $c$, a false positive for $d$, and a false negative for $a$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d08b662e-d1e4-4662-b156-07185ebdd04b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def _label_counts(pairs, pick):\n",
    "    \"\"\"Count, per label, how many rows of `pairs` have that label in the set returned by `pick`.\"\"\"\n",
    "    return pairs.apply(pick, axis=1).explode().value_counts()\n",
    "\n",
    "def _safe_ratio(numerator, denominator):\n",
    "    \"\"\"Return `numerator / denominator`, or 0.0 when the denominator is zero.\"\"\"\n",
    "    return numerator / denominator if denominator else 0.0\n",
    "\n",
    "# Per-label true positives, false positives, and false negatives, following the set\n",
    "# definitions P & A, P - A, and A - P for predicted (P) and actual (A) label sets.\n",
    "tp = _label_counts(results, lambda row: row.preds & row.actual)\n",
    "fp = _label_counts(results, lambda row: row.preds - row.actual)\n",
    "fn = _label_counts(results, lambda row: row.actual - row.preds)\n",
    "\n",
    "# Number of test sentences that actually carry each label.\n",
    "support = actual.explode().value_counts().rename('#')\n",
    "\n",
    "# A label missing from one of the three tallies simply has a count of zero there.\n",
    "counts = pd.concat({'tp': tp, 'fp': fp, 'fn': fn}, axis=1).fillna(0).astype(int)\n",
    "\n",
    "# Undefined ratios (zero denominators) are reported as 0.\n",
    "p = counts.tp.div(counts.tp + counts.fp).fillna(0)\n",
    "r = counts.tp.div(counts.tp + counts.fn).fillna(0)\n",
    "f1 = (2 * p * r) / (p + r)\n",
    "scores = pd.concat({'P': p, 'R': r, 'F1': f1}, axis=1).fillna(0).sort_values(by='F1', ascending=False)\n",
    "\n",
    "# calculate macro scores: the unweighted mean of the per-label scores\n",
    "scores.loc['(macro)'] = scores.mean()\n",
    "\n",
    "# calculate micro scores: counts are pooled over all labels first. A zero denominator\n",
    "# (for example, when the model predicts no labels at all) gives 0 instead of NaN, which\n",
    "# matches how the per-label scores above are handled.\n",
    "micro = counts.sum()\n",
    "mP = _safe_ratio(micro.tp, micro.tp + micro.fp)\n",
    "mR = _safe_ratio(micro.tp, micro.tp + micro.fn)\n",
    "scores.loc['(micro)', 'P'] = mP\n",
    "scores.loc['(micro)', 'R'] = mR\n",
    "scores.loc['(micro)', 'F1'] = _safe_ratio(2 * mP * mR, mP + mR)\n",
    "\n",
    "scores.join(support)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
